{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dfsDR_omdNea"
      },
      "source": [
        "# CodeGemma - Common use cases\n",
        "This notebook demonstrates basic tasks that CodeGemma can solve with the right prompting.\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/CodeGemma/[CodeGemma_1]Common_use_cases.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FaqZItBdeokU"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need a Colab runtime with sufficient resources to run the CodeGemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
        "\n",
        "### Gemma setup\n",
        "\n",
        "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
        "\n",
        "* Get access to Gemma on kaggle.com.\n",
        "* Select a Colab runtime with sufficient resources to run\n",
        "  the Gemma 2B model.\n",
        "* Generate and configure a Kaggle username and an API key as Colab secrets.\n",
        "\n",
        "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CY2kGtsyYpHF"
      },
      "source": [
        "### Configure your credentials\n",
        "\n",
        "Add your Kaggle credentials to the Colab Secrets manager to store them securely.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create new secrets: `KAGGLE_USERNAME` and `KAGGLE_KEY`\n",
        "3. Copy/paste your username into `KAGGLE_USERNAME`\n",
        "4. Copy/paste your key into `KAGGLE_KEY`\n",
        "5. Toggle the buttons on the left to allow notebook access to the secrets.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A9sUQ4WrP-Yr"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iwjo5_Uucxkw"
      },
      "source": [
        "### Install dependencies\n",
        "Run the cell below to install all the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r_nXPEsF7UWQ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.16.1 which is incompatible.\n"
          ]
        }
      ],
      "source": [
        "!pip install -q -U keras keras-nlp"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pOAEiJmnBE0D"
      },
      "source": [
        "## Exploring prompting capabilities"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7CIQwvFT4IgN"
      },
      "source": [
        "### CodeGemma\n",
        "\n",
        "CodeGemma models are text-to-text and text-to-code decoder-only models that specialize in code completion and code generation tasks. The CodeGemma 2B and 7B variants are specially tuned for code infilling tasks.\n",
        "\n",
        "This example uses CodeGemma's fill-in-the-middle (FIM) capability to complete code based on the surrounding context. This is particularly useful in code editors, where code is inserted at the text cursor based on the code before and after it.\n",
        "\n",
        "CodeGemma lets you use four user-defined tokens: three for FIM (`<|fim_prefix|>`, `<|fim_suffix|>`, `<|fim_middle|>`) and a `<|file_separator|>` token for multi-file context support. The cells that follow define these tokens as constants.\n"
      ]
    },
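    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the token layout concrete, the next cell is a minimal sketch of how a FIM prompt can be assembled from the code before and after the cursor. The helper name `build_fim_prompt` and the sample snippet are illustrative assumptions, not part of KerasNLP or CodeGemma. The cell is plain Python and does not need the model to be loaded."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def build_fim_prompt(code_before: str, code_after: str) -> str:\n",
        "    \"\"\"Assemble a fill-in-the-middle prompt (hypothetical helper for illustration).\"\"\"\n",
        "    # Layout: <|fim_prefix|>{code before cursor}<|fim_suffix|>{code after cursor}<|fim_middle|>\n",
        "    return f\"<|fim_prefix|>{code_before}<|fim_suffix|>{code_after}<|fim_middle|>\"\n",
        "\n",
        "\n",
        "# The model is expected to continue after <|fim_middle|> with the code\n",
        "# that belongs at the cursor position.\n",
        "example_prompt = build_fim_prompt(\"def add(a, b):\", \"\\n    return result\")\n",
        "print(example_prompt)"
      ]
    },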
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RTug97S93rs8"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import keras\n",
        "import keras_nlp\n",
        "from google.colab import userdata\n",
        "\n",
        "# Use bfloat16 to roughly halve the memory footprint of the weights compared with float32.\n",
        "keras.config.set_floatx(\"bfloat16\")\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gq_ZwkLQlYgt"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_1.1_2b_en/1/download/metadata.json...\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 143/143 [00:00<00:00, 126kB/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_1.1_2b_en/1/download/config.json...\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 501/501 [00:00<00:00, 428kB/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_1.1_2b_en/1/download/model.weights.h5...\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 4.67G/4.67G [02:25<00:00, 34.4MB/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_1.1_2b_en/1/download/tokenizer.json...\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 315/315 [00:00<00:00, 413kB/s]\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/keras/codegemma/keras/code_gemma_1.1_2b_en/1/download/assets/tokenizer/vocabulary.spm...\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 4.04M/4.04M [00:00<00:00, 7.16MB/s]\n"
          ]
        }
      ],
      "source": [
        "# Load CodeGemma\n",
        "codegemma = keras_nlp.models.GemmaCausalLM.from_preset(\"code_gemma_1.1_2b_en\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XOQZppz2OI4-"
      },
      "outputs": [],
      "source": [
        "# Special tokens used to build FIM prompts and to stop generation\n",
        "BEFORE_CURSOR = \"<|fim_prefix|>\"\n",
        "AFTER_CURSOR = \"<|fim_suffix|>\"\n",
        "AT_CURSOR = \"<|fim_middle|>\"\n",
        "FILE_SEPARATOR = \"<|file_separator|>\"\n",
        "END_TOKEN = codegemma.preprocessor.tokenizer.end_token\n",
        "stop_tokens = (BEFORE_CURSOR, AFTER_CURSOR, AT_CURSOR, FILE_SEPARATOR, END_TOKEN)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X7yYzlwSOPym"
      },
      "outputs": [],
      "source": [
        "# Convert the stop tokens to their vocabulary IDs for use with `generate`.\n",
        "stop_token_ids = tuple(\n",
        "    codegemma.preprocessor.tokenizer.token_to_id(x) for x in stop_tokens\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SciCjeWQB1Jm"
      },
      "source": [
        "#### Prompting example: Code infilling"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KtTQAJKUSEZ5"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "\n",
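        "# Helpers\n",
        "def split_response_by_token(response):\n",
        "    \"\"\"Map each FIM token to the text segment that directly follows it.\"\"\"\n",
        "    mapping = {}\n",
        "    # Split on `<|...|>` tokens, keeping the tokens themselves in the list.\n",
        "    parts = re.split(r\"(<\\|[^\\|\\>]+\\|\\>)\", response)\n",
        "    parts = [item for item in parts if len(item)]\n",
        "    for token in stop_tokens[:3]:  # the three FIM tokens\n",
        "        mapping[token] = \"\"\n",
        "\n",
        "        try:\n",
        "            idx = parts.index(token)\n",
        "            # Keep the following segment unless it is itself a special token.\n",
        "            if parts[idx + 1] not in stop_tokens:\n",
        "                mapping[token] = parts[idx + 1]\n",
        "        except (ValueError, IndexError):\n",
        "            # Token missing, or nothing follows it: keep the empty string.\n",
        "            pass\n",
        "\n",
        "    return mapping"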
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nDozO7l3hkGt"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--- RAW Response ---\n",
            "<|fim_prefix|>def calculate_area_of_rectangle(a: int, b: int) -> int:<|fim_suffix|>\n",
            "    return area<|fim_middle|>\n",
            "    \"\"\"\n",
            "    Calculate the area of a rectangle.\n",
            "\n",
            "    Args:\n",
            "        a (int): The length of the rectangle.\n",
            "        b (int): The width of the rectangle.\n",
            "\n",
            "    Returns:\n",
            "        int: The area of the rectangle.\n",
            "    \"\"\"\n",
            "    area = a * b<|file_separator|>\n",
            "\n",
            "--- The generated (FIM) piece of code: ---\n",
            "\n",
            "    \"\"\"\n",
            "    Calculate the area of a rectangle.\n",
            "\n",
            "    Args:\n",
            "        a (int): The length of the rectangle.\n",
            "        b (int): The width of the rectangle.\n",
            "\n",
            "    Returns:\n",
            "        int: The area of the rectangle.\n",
            "    \"\"\"\n",
            "    area = a * b\n",
            "\n",
            "--- The whole function: ---\n",
            "def calculate_area_of_rectangle(a: int, b: int) -> int: \n",
            "    \"\"\"\n",
            "    Calculate the area of a rectangle.\n",
            "\n",
            "    Args:\n",
            "        a (int): The length of the rectangle.\n",
            "        b (int): The width of the rectangle.\n",
            "\n",
            "    Returns:\n",
            "        int: The area of the rectangle.\n",
            "    \"\"\"\n",
            "    area = a * b \n",
            "    return area\n"
          ]
        }
      ],
      "source": [
        "prefix = \"def calculate_area_of_rectangle(a: int, b: int) -> int:\"\n",
        "suffix = \"\\n    return area\"\n",
        "prompt = f\"{BEFORE_CURSOR}{prefix}{AFTER_CURSOR}{suffix}{AT_CURSOR}\"\n",
        "\n",
        "response = codegemma.generate(prompt, stop_token_ids=stop_token_ids)\n",
        "parts = split_response_by_token(response)\n",
        "\n",
        "print(\"--- RAW Response ---\")\n",
        "print(response)\n",
        "\n",
        "print(\"\\n--- The generated (FIM) piece of code: ---\")\n",
        "print(parts[AT_CURSOR])\n",
        "\n",
        "print(\"\\n--- The whole function: ---\")\n",
        "print(parts[BEFORE_CURSOR], parts[AT_CURSOR], parts[AFTER_CURSOR])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ebBqyJDjB32U"
      },
      "source": [
        "#### Prompting example: Code generation\n",
        "_Note: While the 2B version of CodeGemma was primarily designed for code completion, it can also handle basic code generation tasks. However, for optimal code generation results, we recommend using the instruction-tuned 7B model._"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QaxpsXjbgYSs"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Write a one-liner in Python that check if a year is a leap year.\n",
            "Examples:\n",
            ">>> is_a_leap_year(2016)\n",
            "True\n",
            ">>> is_a_leap_year(2001)\n",
            "False\n",
            ">>> is_a_leap_year(2052)\n",
            "True\n",
            "def is_a_leap_year(year: int) -> bool:\n",
            "    return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)\n",
            "<|file_separator|>\n"
          ]
        }
      ],
      "source": [
        "prompt = \"\"\"Write a one-liner in Python that check if a year is a leap year.\n",
        "Examples:\n",
        ">>> is_a_leap_year(2016)\n",
        "True\n",
        ">>> is_a_leap_year(2001)\n",
        "False\n",
        ">>> is_a_leap_year(2052)\n",
        "True\n",
        "def is_a_leap_year(year: int) -> bool:\"\"\"\n",
        "\n",
        "response = codegemma.generate(prompt, max_length=128)\n",
        "print(response)"
      ]
    },
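    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The raw response above repeats the prompt and ends with a `<|file_separator|>` token. The next cell is a minimal sketch of one way to keep only the newly generated code. The helper `extract_completion` is a hypothetical name introduced for illustration, and the sample response string is copied from the output above rather than produced by a fresh model call."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def extract_completion(response, prompt_text, stop_markers=(\"<|file_separator|>\",)):\n",
        "    \"\"\"Return only the generated text (hypothetical helper for illustration).\"\"\"\n",
        "    # Drop the echoed prompt if the response starts with it.\n",
        "    completion = response\n",
        "    if response.startswith(prompt_text):\n",
        "        completion = response[len(prompt_text):]\n",
        "    # Cut the text at the first occurrence of any stop marker.\n",
        "    for marker in stop_markers:\n",
        "        idx = completion.find(marker)\n",
        "        if idx != -1:\n",
        "            completion = completion[:idx]\n",
        "    return completion.rstrip()\n",
        "\n",
        "\n",
        "# Sample strings taken from the example above (not a new model call).\n",
        "sample_prompt = \"def is_a_leap_year(year: int) -> bool:\"\n",
        "sample_response = (\n",
        "    sample_prompt\n",
        "    + \"\\n    return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)\\n\"\n",
        "    + \"<|file_separator|>\\n\"\n",
        ")\n",
        "print(extract_completion(sample_response, sample_prompt))"
      ]
    },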
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qFVNs_PHOivu"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'status': 'ok', 'restart': True}"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Restart the session to free GPU memory\n",
        "# before loading a new model.\n",
        "get_ipython().kernel.do_shutdown(True)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[CodeGemma_1]Common_use_cases.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
